{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from IPython import display\n",
    "import numpy as np\n",
    "import matplotlib.pylab as plt\n",
    "import pandas as pd\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from matplotlib_inline import backend_inline\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import time\n",
    "import random\n",
    "import math\n",
    "\n",
    "\n",
    "def use_svg_display():\n",
    "    \"\"\"使用svg展示图片\"\"\"\n",
    "    backend_inline.set_matplotlib_formats('svg')\n",
    "\n",
    "\n",
    "def set_figsize(figsize=(3.5, 2.5)):\n",
    "    \"\"\"设置图表大小\"\"\"\n",
    "    use_svg_display()\n",
    "    plt.rcParams['figure.figsize'] = figsize\n",
    "\n",
    "\n",
    "def show_images(imgs, num_rows, num_cols, titles=None, scale=2):\n",
    "    \"\"\"展示图片列表\"\"\"\n",
    "    figsize = (num_cols * scale, num_rows * scale)\n",
    "    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)\n",
    "    axes = axes.flatten()\n",
    "    for i, (ax, img) in enumerate(zip(axes, imgs)):\n",
    "        if torch.is_tensor(img):\n",
    "            # 图片张量\n",
    "            ax.imshow(img.numpy())\n",
    "        else:\n",
    "            # PIL图片\n",
    "            ax.imshow(img)\n",
    "        ax.axes.get_xaxis().set_visible(False)\n",
    "        ax.axes.get_yaxis().set_visible(False)\n",
    "        if titles:\n",
    "            ax.set_title(titles[i])\n",
    "    return axes\n",
    "\n",
    "\n",
    "def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
    "    \"\"\"设置坐标轴\"\"\"\n",
    "    axes.set_xlabel(xlabel)\n",
    "    axes.set_ylabel(ylabel)\n",
    "    axes.set_xscale(xscale)\n",
    "    axes.set_yscale(yscale)\n",
    "    axes.set_xlim(xlim)\n",
    "    axes.set_ylim(ylim)\n",
    "    if legend:\n",
    "        axes.legend(legend)\n",
    "    axes.grid()\n",
    "\n",
    "\n",
    "def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
    "         ylim=None, xscale='linear', yscale='linear',\n",
    "         fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):\n",
    "    \"\"\"绘制数据点\"\"\"\n",
    "    if legend is None:\n",
    "        legend = []\n",
    "\n",
    "    set_figsize(figsize)\n",
    "    axes = axes if axes else plt.gca()\n",
    "\n",
    "    # 如果X有一个轴，输出True\n",
    "    def has_one_axis(X):\n",
    "        return (hasattr(X, \"ndim\") and X.ndim == 1 or isinstance(X, list)\n",
    "                and not hasattr(X[0], \"__len__\"))\n",
    "\n",
    "    if has_one_axis(X):\n",
    "        X = [X]\n",
    "    if Y is None:\n",
    "        X, Y = [[]] * len(X), X\n",
    "    elif has_one_axis(Y):\n",
    "        Y = [Y]\n",
    "    if len(X) != len(Y):\n",
    "        X = X * len(Y)\n",
    "    axes.cla()\n",
    "    for x, y, fmt in zip(X, Y, fmts):\n",
    "        if len(x):\n",
    "            axes.plot(x, y, fmt)\n",
    "        else:\n",
    "            axes.plot(y, fmt)\n",
    "    set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)\n",
    "\n",
    "\n",
    "def load_array(data_arrays, batch_size, is_train=True):\n",
    "    \"\"\"构造一个PyTorch数据迭代器\n",
    "\n",
    "    Defined in :numref:`sec_linear_concise`\"\"\"\n",
    "    dataset = TensorDataset(*data_arrays)\n",
    "    return DataLoader(dataset, batch_size, shuffle=is_train)\n",
    "\n",
    "\n",
    "class Timer():\n",
    "    def __init__(self):\n",
    "        \"\"\"构造函数，创建一个存储运行时间的列表\"\"\"\n",
    "        self.times = []\n",
    "        self.start()\n",
    "\n",
    "    def start(self):\n",
    "        \"\"\"开始计时\"\"\"\n",
    "        self.tik = time.time()\n",
    "\n",
    "    def stop(self):\n",
    "        \"\"\"停止计时\"\"\"\n",
    "        self.times.append(time.time()-self.tik)\n",
    "        return self.times[-1]\n",
    "\n",
    "    def sum(self):\n",
    "        \"\"\"返回累计运行时间\"\"\"\n",
    "        return np.sum(self.times)\n",
    "\n",
    "    def avg(self):\n",
    "        \"\"\"返回平均运行时间\"\"\"\n",
    "        return np.mean(self.times)\n",
    "\n",
    "    def cumsum(self):\n",
    "        \"\"\"返回累计运行时间\"\"\"\n",
    "        return np.cumsum(self.times)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 线性回归重要手撸\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 线性回归\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_data(w, b, n):\n",
    "    \"\"\"根据y=W^T * X + b + 高斯噪声\"\"\"\n",
    "    X = torch.normal(0, 1, size=(n, len(w)))\n",
    "    y = torch.matmul(X, w)+b\n",
    "    y += torch.normal(0, 0.01, size=y.shape)\n",
    "    return X, y.reshape(-1, 1)\n",
    "\n",
    "\n",
    "def random_date(x, y, batchsize=16):\n",
    "    \"\"\"随机取出数据，按照每一轮取batchsize\"\"\"\n",
    "    pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 感知向量机重要手撸\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
